# gin_net.py
import torch
from torch import nn
import torch.optim as optim
from dgl.nn.pytorch.conv import GINConv

class GINNet(nn.Module):
    """Two-layer GIN network mapping ``in_feats -> hidden_feats -> 1``.

    The first ``GINConv`` applies a Linear-ReLU-Linear MLP, the second a
    single Linear layer with one output; ``activation`` is applied to the
    squeezed result. The module owns its Adam optimizer and its training
    loop (see :meth:`fit`).

    Args:
        in_feats: Size of each input node feature vector.
        hidden_feats: Width of the hidden representation.
        n_steps: Maximum number of optimisation steps run by :meth:`fit`.
        lr: Learning rate passed to ``optim.Adam``.
        early_stop: Number of previous validation losses averaged for the
            early-stopping test. ``None`` or a non-positive value disables
            early stopping.
        activation: Callable applied to the final per-node output.
        aggregator_type: Forwarded to both ``GINConv`` layers.
        init_eps: Forwarded to both ``GINConv`` layers.
        learn_eps: Forwarded to both ``GINConv`` layers.
    """

    def __init__(self,
                 in_feats,
                 hidden_feats,
                 n_steps,
                 lr,
                 early_stop=10,
                 activation=torch.sigmoid,
                 aggregator_type='sum',
                 init_eps=0.0,
                 learn_eps=False):
        super().__init__()
        apply1 = nn.Sequential(
            nn.Linear(in_feats, hidden_feats),
            nn.ReLU(),
            nn.Linear(hidden_feats, hidden_feats)
        )
        self.conv1 = GINConv(
            apply_func=apply1,
            aggregator_type=aggregator_type,
            init_eps=init_eps,
            learn_eps=learn_eps
        )
        apply2 = nn.Sequential(
            nn.Linear(hidden_feats, 1)
        )
        self.conv2 = GINConv(
            apply_func=apply2,
            aggregator_type=aggregator_type,
            init_eps=init_eps,
            learn_eps=learn_eps
        )

        self.n_steps    = n_steps
        self.early_stop = early_stop
        self.activation = activation
        self.device     = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        # Move the parameters to their final device *before* handing them to
        # the optimizer, as recommended by the torch.optim documentation.
        self.to(self.device)
        self.optimizer  = optim.Adam(self.parameters(), lr=lr)

    def forward(self, g, feats):
        """Return one activated scalar per node of ``g``.

        ``feats`` is moved to ``self.device``; ``g`` is used as given
        (NOTE(review): presumably it must already live on ``self.device`` --
        :meth:`fit` takes care of that for training).
        """
        h = self.conv1(g, feats.to(self.device))
        h = torch.relu(h)
        h = self.conv2(g, h)
        return self.activation(h.squeeze(-1))

    def train(self, g=True, edge_map=None, train_idx=None, y_train=None,
              valid_idx=None, y_valid=None, verbose=False, *, mode=None):
        """Run the training loop, or switch train/eval mode.

        This method historically replaced ``nn.Module.train``, which broke
        ``model.eval()`` (it calls ``self.train(False)``). Both usages are
        now supported:

        * ``train(g, edge_map, train_idx, y_train, valid_idx, y_valid,
          verbose=False)`` -- fit the model (see :meth:`fit`); returns
          ``None`` as before.
        * ``train()``, ``train(True/False)`` or ``train(mode=...)`` -- behave
          like ``nn.Module.train`` and return ``self``.
        """
        if mode is not None:
            return super().train(mode)
        if isinstance(g, bool):
            # Called by nn.Module.eval() / standard PyTorch code.
            return super().train(g)
        self.fit(g, edge_map, train_idx, y_train, valid_idx, y_valid,
                 verbose=verbose)
        return None

    def fit(self, g, edge_map, train_idx, y_train, valid_idx, y_valid, verbose=False):
        """Optimise the MSE on the training nodes for up to ``n_steps`` steps.

        Args:
            g: Graph whose ``ndata['feat']`` holds the node features; it is
                moved to ``self.device`` via ``g.to``.
            edge_map: Unused; kept for interface compatibility.
            train_idx: Index selecting the training nodes from the output.
            y_train: Targets for ``train_idx``.
            valid_idx: Index selecting the validation nodes from the output.
            y_valid: Targets for ``valid_idx``.
            verbose: Print losses every 1000 epochs and on early stopping.

        Returns:
            The list of per-epoch validation losses (Python floats).
        """
        super().train(True)
        loss_fn = nn.MSELoss()

        # Loop-invariant device transfers are done once, up front.
        g = g.to(self.device)
        feats = g.ndata['feat']
        y_train = y_train.to(self.device)
        y_valid = y_valid.to(self.device)

        patience = self.early_stop
        # A zero/negative/None window would divide by zero (or fail the
        # comparison) below, so treat it as "early stopping disabled".
        use_early_stop = patience is not None and patience > 0

        valid_losses = []
        for epoch in range(self.n_steps):
            self.optimizer.zero_grad()
            out = self.forward(g, feats)
            loss_t = loss_fn(out[train_idx], y_train)
            # Validation loss is only monitored, never back-propagated:
            # detach so no autograd graph is built for it.
            loss_v = loss_fn(out.detach()[valid_idx], y_valid).item()
            loss_t.backward()
            self.optimizer.step()

            valid_losses.append(loss_v)
            if verbose and epoch % 1000 == 0:
                print(f"epoch {epoch}  train={loss_t.item():.4f}  valid={loss_v:.4f}")

            if use_early_stop and epoch > patience:
                # Stop once the newest validation loss exceeds the mean of
                # the `patience` losses that preceded it.
                previous = valid_losses[-(patience + 1):-1]
                if valid_losses[-1] > sum(previous) / patience:
                    if verbose: print("Early stopping.")
                    break

        return valid_losses
